Pytorch Dataloader参数及源码详解 您所在的位置:网站首页 pytorch dataset详解 Pytorch Dataloader参数及源码详解

Pytorch Dataloader参数及源码详解

2023-11-26 18:00| 来源: 网络整理| 查看: 265

文章目录 Dataloader

Dataloader

首先看一下 DataLoader 迭代器基类 `_BaseDataLoaderIter` 中 `__next__` 的实现。为方便理解，这里只考虑 num_workers 为 0 的情况（num_workers 是用于读取数据的子进程数，大于 0 时可以并行读取数据；为 0 表示在主进程中读取数据）。

class _BaseDataLoaderIter(object):
    """Base class for the iterators returned by ``iter(DataLoader)``.

    On construction it copies the relevant settings from ``loader`` (dataset,
    index sampler, collate function, worker count, ...) and creates an
    iterator over ``loader._index_sampler``.  Subclasses provide the actual
    data loading by implementing :meth:`_next_data`.
    """

    def __init__(self, loader: DataLoader) -> None:
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._prefetch_factor = loader.prefetch_factor
        # Pinning is only requested when CUDA is actually available.
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        # Random int64 drawn from the loader's generator; presumably used to
        # seed worker processes (its consumers are not shown here).
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0

    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        """Restart iteration over ``loader``'s index sampler (used by persistent workers)."""
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called

    def _next_index(self):
        """Return the next index (or list of indices) produced by the index sampler."""
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        """Load and return the next sample/batch; implemented by subclasses."""
        raise NotImplementedError

    # The most important method: each call yields one sample/batch.
    def __next__(self) -> Any:
        # NOTE(review): _reset() requires a `loader` argument, so this branch
        # would raise TypeError if it were ever taken. __init__ always sets
        # _sampler_iter, so it appears unreachable in the code shown; kept as
        # in upstream PyTorch.
        if self._sampler_iter is None:
            self._reset()
        data = self._next_data()
        self._num_yielded += 1
        # Warn when an IterableDataset yields more items than the length that
        # was reported by an earlier len(dataloader) call.
        if self._dataset_kind == _DatasetKind.Iterable and \
                self._IterableDataset_len_called is not None and \
                self._num_yielded > self._IterableDataset_len_called:
            warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                        "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                              self._num_yielded)
            if self._num_workers > 0:
                warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                             "IterableDataset replica at each worker. Please see "
                             "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
            warnings.warn(warn_msg)
        return data

    next = __next__  # Python 2 compatibility

    def __len__(self) -> int:
        return len(self._index_sampler)

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        # Format the class name into the message (the original passed it as a
        # second exception argument, leaving a literal "{}" in the message).
        raise NotImplementedError("{} cannot be pickled".format(self.__class__.__name__))

在阅读上面代码前，我们可以假设数据是一组图像，每一张图像对应一个 index，那么要读取数据只需要知道对应的 index 即可。也就是代码中 `_next_index()` 返回的值：启用自动批处理时它是一个 batch 的 index 列表；在旧版本源码中这个变量名为 indices。选取 index 的方式有多种，可以按顺序，也可以乱序，这项工作由 Sampler 完成。现在你不需要了解具体细节，后面会介绍，只需要知道 DataLoader 和 Sampler 在这里产生关系。

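那么 Dataset 和 DataLoader 在什么时候产生关系呢？就在 `_SingleProcessDataLoaderIter._next_data()` 中的 `data = self._dataset_fetcher.fetch(index)` 这一行（见下文源码）。我们已经拿到了 index，下一步只需要根据 index 从 dataset 中读取数据即可。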

`_next_data()` 中再下面的 if 语句的作用是：如果 pin_memory=True（且 CUDA 可用），PyTorch 会把数据中的 Tensor 拷贝到锁页内存（pinned memory）中。注意这一步并不会把数据拷贝到 GPU。它的作用是让之后从 CPU 到 GPU 的拷贝更快（并且可以异步进行），从而加速数据传输。

首先来看 DataLoader 的所有参数。看 DataLoader 类时，首先要明确一个很重要的概念：sampler 负责产生索引，dataset 负责存放数据。DataLoader 类中的 `__iter__` 方法实现了按照 sampler 给出的索引从 dataset 中取值的功能。DataLoader 类要经过多个函数才能从索引拿到数据，调用顺序是（以单进程为例）：`DataLoader.__iter__()` -> `DataLoader._get_iterator()` -> `_SingleProcessDataLoaderIter(DataLoader 实例)` -> `_SingleProcessDataLoaderIter._next_data()`。接下来逐一看这些函数的具体实现。

class DataLoader(Generic[T_co]):
    ...  # (other members omitted in this excerpt)

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: _collate_fn_t = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        """Signature-only excerpt of ``DataLoader.__init__``.

        Lists every constructor parameter with its default value;
        ``prefetch_factor`` and ``persistent_workers`` are keyword-only.
        The constructor body is omitted here; it appears in the complete
        DataLoader listing further down the article.
        """
        ...

首先来看下DataLoader类中的__iter__方法:

class DataLoader(Generic[T_co]):
    ...  # (other members omitted in this excerpt; the full class is listed below)

    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created every time to avoid resetting its state.
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused.
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            # Non-persistent case (always taken when num_workers == 0):
            # a fresh iterator is built through _get_iterator().
            return self._get_iterator()

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            # Single-process loading; the DataLoader instance itself is passed in.
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    """Iterator used when ``num_workers == 0``: data is loaded in the calling process."""

    def __init__(self, loader):
        # `loader` is the DataLoader instance; the base class copies its settings.
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        # _next_index() is inherited from _BaseDataLoaderIter: it draws the next
        # element from an iterator over loader._index_sampler, i.e. a single
        # index, or a list of indices when auto-collation (batching) is on.
        # See DataLoader._index_sampler below.
        index = self._next_index()  # may raise StopIteration
        # Fetch the actual samples for that index / batch of indices.
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data


class DataLoader(Generic[T_co]):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into CUDA pinned memory before returning them.  If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
        prefetch_factor (int, optional, keyword-only arg): Number of samples loaded
            in advance by each worker. ``2`` means there will be a total of
            2 * num_workers samples prefetched across all workers. (default: ``2``)
        persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
            the worker processes after a dataset has been consumed once. This allows to
            maintain the workers `Dataset` instances alive. (default: ``False``)


    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function. See
                 :ref:`multiprocessing-best-practices` on more details related
                 to multiprocessing in PyTorch.

    .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
                 When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
                 it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
                 rounding depending on :attr:`drop_last`, regardless of multi-process loading
                 configurations. This represents the best guess PyTorch can make because PyTorch
                 trusts user :attr:`dataset` code in correctly handling multi-process
                 loading to avoid duplicate data.

                 However, if sharding results in multiple workers having incomplete last batches,
                 this estimate can still be inaccurate, because (1) an otherwise complete batch can
                 be broken into multiple ones and (2) more than one batch worth of samples can be
                 dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
                 cases in general.

                 See `Dataset Types`_ for more details on these two types of datasets and how
                 :class:`~torch.utils.data.IterableDataset` interacts with
                 `Multi-process data loading`_.
    """
    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Sampler
    prefetch_factor: int
    _iterator : Optional['_BaseDataLoaderIter']
    # Class-level default so __setattr__ can read it before __init__ finishes.
    __initialized = False

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: _collate_fn_t = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        torch._C._log_api_usage_once("python.data_loader")  # type: ignore

        # NOTE(review): the checks from here down to the persistent_workers
        # check were lost when this page was scraped (text between '<' and '>'
        # was stripped as if it were HTML). They are reconstructed from upstream
        # torch/utils/data/dataloader.py (v1.7); verify against the exact
        # version being discussed.
        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
                             'let num_workers > 0 to enable multiprocessing.')
        assert prefetch_factor > 0

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        # Goes through the `multiprocessing_context` property setter (validation below).
        self.multiprocessing_context = multiprocessing_context

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and IterableDataset ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data lost (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `Iterabledataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` defined above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    # Cannot statically verify that dataset is Sized
                    # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        # From here on, __setattr__ rejects changes to the sampling attributes.
        self.__initialized = True
        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

        self._iterator = None

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        """Build a single-process iterator when num_workers == 0, else a multi-process one."""
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

    @property
    def multiprocessing_context(self):
        return self.__multiprocessing_context

    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        """Validate and store the context; a start-method string is turned into a context object.

        Raises:
            ValueError: a context is given while num_workers == 0, the start
                method name is unknown, or contexts are unsupported.
            TypeError: the value is neither a valid context nor a start-method string.
        """
        if multiprocessing_context is not None:
            if self.num_workers > 0:
                if not multiprocessing._supports_context:
                    raise ValueError('multiprocessing_context relies on Python >= 3.4, with '
                                     'support for different start methods')

                if isinstance(multiprocessing_context, string_classes):
                    valid_start_methods = multiprocessing.get_all_start_methods()
                    if multiprocessing_context not in valid_start_methods:
                        raise ValueError(
                            ('multiprocessing_context option '
                             'should specify a valid start method in {!r}, but got '
                             'multiprocessing_context={!r}').format(valid_start_methods, multiprocessing_context))
                    # error: Argument 1 to "get_context" has incompatible type "Union[str, bytes]"; expected "str"  [arg-type]
                    multiprocessing_context = multiprocessing.get_context(multiprocessing_context)  # type: ignore

                if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                    raise TypeError(('multiprocessing_context option should be a valid context '
                                     'object or a string specifying the start method, but got '
                                     'multiprocessing_context={}').format(multiprocessing_context))
            else:
                raise ValueError(('multiprocessing_context can only be used with '
                                  'multi-process loading (num_workers > 0), but got '
                                  'num_workers={}').format(self.num_workers))

        self.__multiprocessing_context = multiprocessing_context

    def __setattr__(self, attr, val):
        # After __init__ completes, the attributes that determine sampling and
        # batching become read-only.
        if self.__initialized and attr in (
                'batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset', 'persistent_workers'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
    # since '_BaseDataLoaderIter' references 'DataLoader'.
    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created every time to avoid resetting its state.
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused.
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

    @property
    def _auto_collation(self):
        # Automatic batching is on exactly when a batch_sampler exists
        # (user-supplied or built from batch_size in __init__).
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

    def __len__(self) -> int:
        if self._dataset_kind == _DatasetKind.Iterable:
            # NOTE [ IterableDataset and __len__ ]
            #
            # For `IterableDataset`, `__len__` could be inaccurate when one naively
            # does multi-processing data loading, since the samples will be duplicated.
            # However, no real use case should be actually using that behavior, so
            # it should count as a user error. We should generally trust user
            # code to do the proper thing (e.g., configure each replica differently
            # in `__iter__`), and give us the correct `__len__` if they choose to
            # implement it (this will still throw if the dataset does not implement
            # a `__len__`).
            #
            # To provide a further warning, we track if `__len__` was called on the
            # `DataLoader`, save the returned value in
            # `self._IterableDataset_len_called`, and warn if the iterator ends up
            # yielding more than this number of samples.

            # Cannot statically verify that dataset is Sized
            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore
            if self.batch_size is not None:  # IterableDataset doesn't allow custom sampler or batch_sampler
                from math import ceil
                if self.drop_last:
                    length = length // self.batch_size
                else:
                    length = ceil(length / self.batch_size)
            return length
        else:
            return len(self._index_sampler)


【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有